{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in TensorFlow (3 pts)\n",
    "\n",
    "Just like we did before for Q-learning, this time we'll design a TensorFlow network to learn `CartPole-v0` via policy gradient (REINFORCE).\n",
    "\n",
    "Most of the code in this notebook is taken from approximate Q-learning, so you'll find it more or less familiar and even simpler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "if 'google.colab' in sys.modules:\n",
    "    %tensorflow_version 1.x\n",
    "    \n",
    "    if not os.path.exists('.setup_complete'):\n",
    "        !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
    "        !pip install -q gym[classic_control]==0.18.0\n",
    "        !touch .setup_complete\n",
    "\n",
    "# This code creates a virtual display to draw game images on.\n",
    "# It will have no effect if your machine has a monitor.\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A caveat: with some versions of `pyglet`, the following cell may crash with `NameError: name 'base' is not defined`. The corresponding bug report is [here](https://github.com/pyglet/pyglet/issues/134). If you see this error, try restarting the kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v0\")\n",
    "\n",
    "# gym compatibility: unwrap TimeLimit\n",
    "if hasattr(env, '_max_episode_steps'):\n",
    "    env = env.env\n",
    "\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For REINFORCE algorithm, we'll need a model that predicts action probabilities given states.\n",
    "\n",
    "For numerical stability, please __do not include the softmax layer into your network architecture__.\n",
    "We'll use softmax or log-softmax where appropriate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create input variables. We only need <s, a, r> for REINFORCE\n",
    "ph_states = tf.placeholder('float32', (None,) + state_dim, name=\"states\")\n",
    "ph_actions = tf.placeholder('int32', name=\"action_ids\")\n",
    "ph_cumulative_rewards = tf.placeholder('float32', name=\"cumulative_returns\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "\n",
    "<YOUR CODE: define network graph using raw TF, Keras, or any other library you prefer>\n",
    "\n",
    "logits = <YOUR CODE: symbolic outputs of your network _before_ softmax>\n",
    "\n",
    "policy = tf.nn.softmax(logits)\n",
    "log_policy = tf.nn.log_softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model parameters\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_probs(states):\n",
    "    \"\"\" \n",
    "    Predict action probabilities given states.\n",
    "    :param states: numpy array of shape [batch, state_shape]\n",
    "    :returns: numpy array of shape [batch, n_actions]\n",
    "    \"\"\"\n",
    "    return policy.eval({ph_states: [states]})[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Play the game\n",
    "\n",
    "We can now use our newly built agent to play the game."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_session(env, t_max=1000):\n",
    "    \"\"\" \n",
    "    Play a full session with REINFORCE agent.\n",
    "    Returns sequences of states, actions, and rewards.\n",
    "    \"\"\"\n",
    "    # arrays to record session\n",
    "    states, actions, rewards = [], [], []\n",
    "    s = env.reset()\n",
    "\n",
    "    for t in range(t_max):\n",
    "        # action probabilities array aka pi(a|s)\n",
    "        action_probs = predict_probs(s)\n",
    "\n",
    "        # Sample action with given probabilities.\n",
    "        a = <YOUR CODE>\n",
    "        new_s, r, done, info = env.step(a)\n",
    "\n",
    "        # record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "\n",
    "        s = new_s\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    return states, actions, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it\n",
    "states, actions, rewards = generate_session(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "G_t &= r_t + \\gamma r_{t + 1} + \\gamma^2 r_{t + 2} + \\ldots \\\\\n",
    "&= \\sum_{i = t}^T \\gamma^{i - t} r_i \\\\\n",
    "&= r_t + \\gamma * G_{t + 1}\n",
    "\\end{align*}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards,  # rewards at each step\n",
    "                           gamma=0.99  # discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    Take a list of immediate rewards r(s,a) for the whole session \n",
    "    and compute cumulative returns (a.k.a. G(s,a) in Sutton '16).\n",
    "    \n",
    "    G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "\n",
    "    A simple way to compute cumulative rewards is to iterate from the last\n",
    "    to the first timestep and compute G_t = r_t + gamma*G_{t+1} recurrently\n",
    "\n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "    <YOUR CODE>\n",
    "    return <YOUR CODE: array of cumulative rewards>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(get_cumulative_rewards(range(100))) == 100\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, 0, 0, 1, 0], gamma=0.9),\n",
    "    [1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, -2, 3, -4, 0], gamma=0.5),\n",
    "    [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(\n",
    "    get_cumulative_rewards([0, 0, 1, 2, 3, 4, 0], gamma=0),\n",
    "    [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define objective and update over policy gradient.\n",
    "\n",
    "Our objective function is\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum_{s_i,a_i} G(s_i,a_i) $$\n",
    "\n",
    "REINFORCE defines a way to compute the gradient of the expected reward with respect to policy parameters. The formula is as follows:\n",
    "\n",
    "$$ \\nabla_\\theta \\hat J(\\theta) \\approx { 1 \\over N } \\sum_{s_i, a_i} \\nabla_\\theta \\log \\pi_\\theta (a_i \\mid s_i) \\cdot G_t(s_i, a_i) $$\n",
    "\n",
    "We can abuse Tensorflow's capabilities for automatic differentiation by defining our objective function as follows:\n",
    "\n",
    "$$ \\hat J(\\theta) \\approx { 1 \\over N } \\sum_{s_i, a_i} \\log \\pi_\\theta (a_i \\mid s_i) \\cdot G_t(s_i, a_i) $$\n",
    "\n",
    "When you compute the gradient of that function with respect to network weights $\\theta$, it will become exactly the policy gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code selects the log-probabilities (log pi(a_i|s_i)) for those actions that were actually played.\n",
    "indices = tf.stack([tf.range(tf.shape(log_policy)[0]), ph_actions], axis=-1)\n",
    "log_policy_for_actions = tf.gather_nd(log_policy, indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Policy objective as in the last formula. Please use reduce_mean, not reduce_sum.\n",
    "# You may use log_policy_for_actions to get log probabilities for actions taken.\n",
    "# Also recall that we defined ph_cumulative_rewards earlier.\n",
    "\n",
    "J = <YOUR CODE>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder, for a discrete probability distribution (like the one our policy outputs), entropy is defined as:\n",
    "\n",
    "$$ \\operatorname{entropy}(p) = -\\sum_{i = 1}^n p_i \\cdot \\log p_i $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entropy regularization. If you don't add it, the policy will quickly deteriorate to\n",
    "# being deterministic, harming exploration.\n",
    "\n",
    "entropy = <YOUR CODE: compute entropy. Do not forget the sign!>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Maximizing X is the same as minimizing -X, hence the sign.\n",
    "loss = -(J + 0.1 * entropy)\n",
    "\n",
    "update = tf.train.AdamOptimizer().minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_on_session(states, actions, rewards, t_max=1000):\n",
    "    \"\"\"given full session, trains agent with policy gradient\"\"\"\n",
    "    cumulative_rewards = get_cumulative_rewards(rewards)\n",
    "    update.run({\n",
    "        ph_states: states,\n",
    "        ph_actions: actions,\n",
    "        ph_cumulative_rewards: cumulative_rewards,\n",
    "    })\n",
    "    return sum(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize optimizer parameters\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The actual training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    rewards = [train_on_session(*generate_session(env)) for _ in range(100)]  # generate new sessions\n",
    "\n",
    "    print(\"mean reward: %.3f\" % (np.mean(rewards)))\n",
    "\n",
    "    if np.mean(rewards) > 300:\n",
    "        print(\"You Win!\")  # but you can train even further\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results & video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record sessions\n",
    "\n",
    "import gym.wrappers\n",
    "\n",
    "with gym.wrappers.Monitor(gym.make(\"CartPole-v0\"), directory=\"videos\", force=True) as env_monitor:\n",
    "    sessions = [generate_session(env_monitor) for _ in range(100)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show video. This may not work in some setups. If it doesn't\n",
    "# work for you, you can download the videos and view them locally.\n",
    "\n",
    "from pathlib import Path\n",
    "from base64 import b64encode\n",
    "from IPython.display import HTML\n",
    "\n",
    "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n",
    "video_path = video_paths[-1]  # You can also try other indices\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "    # https://stackoverflow.com/a/57378660/1214547\n",
    "    with video_path.open('rb') as fp:\n",
    "        mp4 = fp.read()\n",
    "    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n",
    "else:\n",
    "    data_url = str(video_path)\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(data_url))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
